{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2bdba67e",
   "metadata": {},
   "source": [
    "## Description:\n",
    "这个笔记本主要是先把YoutubeDNN召回跑起来， 采用的数据集是data_process下面的train_data.csv， 这个是用户历史行为数据+用户画像数据， 模型的话先调用包实现， 这里面主要完成的步骤是：\n",
    "1. 导入数据， 划分出测试集来， 最后一天用户id进行评估\n",
    "2. 用训练集的数据构造YoutubeDNN的数据集出来， 这里采用滑窗的方式构造， 但由于有些序列很长，尝试通过滑动步长优化\n",
    "3. 构造YoutubeDNN的输入， 由于是调用deepmatch的YoutubeDNN，所以需要把输入特殊构造\n",
    "4. 建立YoutubeDNN模型训练\n",
    "5. 从YoutubeDNN中拿到用户embedding和doc的embedding，保存起来\n",
    "6. 用第11天的测试集，进行评估，看看YoutubeDNN的召回效果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1a2d4f02",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:\n",
      "DeepCTR version 0.9.0 detected. Your version is 0.8.2.\n",
      "Use `pip install -U deepctr` to upgrade.Changelog: https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import pickle\n",
    "import random\n",
    "from datetime import datetime\n",
    "import collections\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "# 从utils里面导入函数\n",
    "from utils import gen_data_set, gen_model_input, train_youtube_model\n",
    "from utils import get_embeddings, get_youtube_recall_res\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "#os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"  # 1显示所有信息， 2显示error和warnings, 3显示error"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c647e36",
   "metadata": {},
   "source": [
    "## 1. 导入数据，划分数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bb44f72f",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = 'data_process'\n",
    "data = pd.read_csv(os.path.join(data_path, 'train_data.csv'), index_col=0, parse_dates=['expo_time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d0e273ce",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3939989, 18)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "54970877",
   "metadata": {},
   "outputs": [],
   "source": [
    "latest_time = data['expo_time'].max()\n",
    "# 添加example_age字段\n",
    "data['example_age'] = (pd.to_datetime(latest_time) - data['expo_time'])\n",
    "# 转换成小时的形式， 上面两式子相减是pandas的timedelta类型， 只有days和seconds属性\n",
    "data['example_age'] = data['example_age'].apply(lambda x: x.days*24 + x.seconds//3600)\n",
    "\n",
    "# 归一化\n",
    "minmax = MinMaxScaler()\n",
    "data['example_age'] = minmax.fit_transform(data['example_age'].values.reshape(-1, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "054ea99a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>article_id</th>\n",
       "      <th>expo_time</th>\n",
       "      <th>net_status</th>\n",
       "      <th>flush_nums</th>\n",
       "      <th>exop_position</th>\n",
       "      <th>click</th>\n",
       "      <th>duration</th>\n",
       "      <th>device</th>\n",
       "      <th>os</th>\n",
       "      <th>province</th>\n",
       "      <th>city</th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>ctime</th>\n",
       "      <th>img_num</th>\n",
       "      <th>cat_1</th>\n",
       "      <th>cat_2</th>\n",
       "      <th>example_age</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1000541010</td>\n",
       "      <td>464467760</td>\n",
       "      <td>2021-06-30 09:57:14</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>13</td>\n",
       "      <td>1</td>\n",
       "      <td>28</td>\n",
       "      <td>V2054A</td>\n",
       "      <td>Android</td>\n",
       "      <td>上海</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_0_24</td>\n",
       "      <td>female</td>\n",
       "      <td>2021-06-29 14:46:43</td>\n",
       "      <td>3.0</td>\n",
       "      <td>娱乐</td>\n",
       "      <td>娱乐/港台明星</td>\n",
       "      <td>0.946108</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1000541010</td>\n",
       "      <td>463850913</td>\n",
       "      <td>2021-06-30 09:57:14</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>15</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>V2054A</td>\n",
       "      <td>Android</td>\n",
       "      <td>上海</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_0_24</td>\n",
       "      <td>female</td>\n",
       "      <td>2021-06-27 22:29:13</td>\n",
       "      <td>11.0</td>\n",
       "      <td>时尚</td>\n",
       "      <td>时尚/女性时尚</td>\n",
       "      <td>0.946108</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1000541010</td>\n",
       "      <td>464022440</td>\n",
       "      <td>2021-06-30 09:57:14</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>17</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>V2054A</td>\n",
       "      <td>Android</td>\n",
       "      <td>上海</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_0_24</td>\n",
       "      <td>female</td>\n",
       "      <td>2021-06-28 12:22:54</td>\n",
       "      <td>7.0</td>\n",
       "      <td>农村</td>\n",
       "      <td>农村/农业资讯</td>\n",
       "      <td>0.946108</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1000541010</td>\n",
       "      <td>464586545</td>\n",
       "      <td>2021-06-30 09:58:31</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>V2054A</td>\n",
       "      <td>Android</td>\n",
       "      <td>上海</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_0_24</td>\n",
       "      <td>female</td>\n",
       "      <td>2021-06-29 13:25:06</td>\n",
       "      <td>5.0</td>\n",
       "      <td>娱乐</td>\n",
       "      <td>娱乐/港台明星</td>\n",
       "      <td>0.946108</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1000541010</td>\n",
       "      <td>465352885</td>\n",
       "      <td>2021-07-03 18:13:03</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>18</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>V2054A</td>\n",
       "      <td>Android</td>\n",
       "      <td>上海</td>\n",
       "      <td>上海</td>\n",
       "      <td>A_0_24</td>\n",
       "      <td>female</td>\n",
       "      <td>2021-07-02 10:43:51</td>\n",
       "      <td>18.0</td>\n",
       "      <td>娱乐</td>\n",
       "      <td>娱乐/港台明星</td>\n",
       "      <td>0.461078</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      user_id  article_id           expo_time  net_status  flush_nums  \\\n",
       "0  1000541010   464467760 2021-06-30 09:57:14           2           0   \n",
       "1  1000541010   463850913 2021-06-30 09:57:14           2           0   \n",
       "2  1000541010   464022440 2021-06-30 09:57:14           2           0   \n",
       "3  1000541010   464586545 2021-06-30 09:58:31           2           1   \n",
       "4  1000541010   465352885 2021-07-03 18:13:03           5           0   \n",
       "\n",
       "   exop_position  click  duration  device       os province city     age  \\\n",
       "0             13      1        28  V2054A  Android       上海   上海  A_0_24   \n",
       "1             15      0         0  V2054A  Android       上海   上海  A_0_24   \n",
       "2             17      0         0  V2054A  Android       上海   上海  A_0_24   \n",
       "3             20      0         0  V2054A  Android       上海   上海  A_0_24   \n",
       "4             18      0         0  V2054A  Android       上海   上海  A_0_24   \n",
       "\n",
       "   gender                ctime img_num cat_1    cat_2  example_age  \n",
       "0  female  2021-06-29 14:46:43     3.0    娱乐  娱乐/港台明星     0.946108  \n",
       "1  female  2021-06-27 22:29:13    11.0    时尚  时尚/女性时尚     0.946108  \n",
       "2  female  2021-06-28 12:22:54     7.0    农村  农村/农业资讯     0.946108  \n",
       "3  female  2021-06-29 13:25:06     5.0    娱乐  娱乐/港台明星     0.946108  \n",
       "4  female  2021-07-02 10:43:51    18.0    娱乐  娱乐/港台明星     0.461078  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8f8197be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 选择出需要用到的列\n",
    "use_cols = ['user_id', 'article_id', 'expo_time', 'net_status', 'exop_position', 'duration', 'device', 'city', 'age', 'gender', 'example_age', 'click']\n",
    "data_new = data[use_cols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f58133ec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1158,)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_new['exop_position'].unique().shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57b094a2",
   "metadata": {},
   "source": [
    "## 划分测试集和训练集\n",
    "* 训练集， 每个用户的历史点击，去掉最后一次\n",
    "* 测试集， 每个用户的最后一次点击"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "37dee164",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 按照用户分组，然后把最后一个item拿出来\n",
    "click_df = data_new[data_new['click']==1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ac79a1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hist_and_last_click(all_click):\n",
    "    all_click = all_click.sort_values(by=['user_id', 'expo_time'])\n",
    "    click_last_df = all_click.groupby('user_id').tail(1)\n",
    "    \n",
    "    # 如果用户只有一个点击，hist为空了，会导致训练的时候这个用户不可见，此时默认泄露一下\n",
    "    def hist_func(user_df):\n",
    "        if len(user_df) == 1:\n",
    "            return user_df\n",
    "        else:\n",
    "            return user_df[:-1]\n",
    "\n",
    "    click_hist_df = all_click.groupby('user_id').apply(hist_func).reset_index(drop=True)\n",
    "\n",
    "    return click_hist_df, click_last_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5d111df1",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_click_hist_df, user_click_last_df = get_hist_and_last_click(click_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f026c11",
   "metadata": {},
   "source": [
    "## 2. YoutubeDNN召回"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3e3a6458",
   "metadata": {},
   "outputs": [],
   "source": [
    "def youtubednn_recall(data, topk=200, embedding_dim=8, his_seq_maxlen=50, negsample=0,\n",
    "                      batch_size=64, epochs=1, verbose=1, validation_split=0.0):\n",
    "    \"\"\"通过YouTubeDNN模型，计算用户向量和文章向量\n",
    "    param: data: 用户日志数据\n",
    "    topk: 对于每个用户，召回多少篇文章\n",
    "    \"\"\"\n",
    "    user_id_raw = data[['user_id']].drop_duplicates('user_id')\n",
    "    doc_id_raw = data[['article_id']].drop_duplicates('article_id')\n",
    "    \n",
    "    # 类别数据编码   \n",
    "    base_features = ['user_id', 'article_id', 'city', 'age', 'gender']\n",
    "    feature_max_idx = {}\n",
    "    for f in base_features:\n",
    "        lbe = LabelEncoder()\n",
    "        data[f] = lbe.fit_transform(data[f])\n",
    "        feature_max_idx[f] = data[f].max() + 1\n",
    "        \n",
    "    # 构建用户id词典和doc的id词典，方便从用户idx找到原始的id\n",
    "    user_id_enc = data[['user_id']].drop_duplicates('user_id')\n",
    "    doc_id_enc = data[['article_id']].drop_duplicates('article_id')\n",
    "    user_idx_2_rawid = dict(zip(user_id_enc['user_id'], user_id_raw['user_id']))\n",
    "    doc_idx_2_rawid = dict(zip(doc_id_enc['article_id'], doc_id_raw['article_id']))\n",
    "    \n",
    "    # 保存下每篇文章的被点击数量， 方便后面高热文章的打压\n",
    "    doc_clicked_count_df = data.groupby('article_id')['click'].apply(lambda x: x.count()).reset_index()\n",
    "    doc_clicked_count_dict = dict(zip(doc_clicked_count_df['article_id'], doc_clicked_count_df['click']))\n",
    "\n",
    "    train_set, test_set = gen_data_set(data, doc_clicked_count_dict, negsample, control_users=False)\n",
    "    \n",
    "    # 构造youtubeDNN模型的输入\n",
    "    train_model_input, train_label = gen_model_input(train_set, his_seq_maxlen)\n",
    "    test_model_input, test_label = gen_model_input(test_set, his_seq_maxlen)\n",
    "    \n",
    "    # 构建模型并完成训练\n",
    "    model = train_youtube_model(train_model_input, train_label, embedding_dim, feature_max_idx, his_seq_maxlen, batch_size, epochs, verbose, validation_split)\n",
    "    \n",
    "    # 获得用户embedding和doc的embedding， 并进行保存\n",
    "    user_embs, doc_embs = get_embeddings(model, test_model_input, user_idx_2_rawid, doc_idx_2_rawid)\n",
    "    \n",
    "    # 对每个用户，拿到召回结果并返回回来\n",
    "    user_recall_doc_dict = get_youtube_recall_res(user_embs, doc_embs, user_idx_2_rawid, doc_idx_2_rawid, topk)\n",
    "    \n",
    "    return user_recall_doc_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "54421115",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 20000/20000 [50:14<00:00,  6.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1904804 samples\n",
      "1904804/1904804 [==============================] - 3974s 2ms/sample - loss: 0.7390\n"
     ]
    }
   ],
   "source": [
    "user_recall_doc_dict = youtubednn_recall(user_click_hist_df, negsample=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "76ad32f5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20000"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(user_recall_doc_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fba0fa4b",
   "metadata": {},
   "source": [
    "## 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3ad17002",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20000, 12)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "user_click_last_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ec0b97e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 依次评估召回的前10, 20, 30, 40, 50个文章中的击中率\n",
    "def metrics_recall(user_recall_items_dict, trn_last_click_df, topk=100):\n",
    "    last_click_item_dict = dict(zip(trn_last_click_df['user_id'], trn_last_click_df['article_id']))\n",
    "    user_num = len(user_recall_items_dict)\n",
    "    \n",
    "    for k in range(50, topk+1, 50):\n",
    "        hit_num = 0\n",
    "        for user, item_list in user_recall_items_dict.items():\n",
    "            if user in last_click_item_dict:\n",
    "                # 获取前k个召回的结果\n",
    "                tmp_recall_items = [x[0] for x in user_recall_items_dict[user][:k]]\n",
    "                if last_click_item_dict[user] in set(tmp_recall_items):\n",
    "                    hit_num += 1\n",
    "        \n",
    "        hit_rate = round(hit_num * 1.0 / user_num, 5)\n",
    "        print(' topk: ', k, ' : ', 'hit_num: ', hit_num, 'hit_rate: ', hit_rate, 'user_num : ', user_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "52cc0435",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " topk:  50  :  hit_num:  22 hit_rate:  0.0011 user_num :  20000\n",
      " topk:  100  :  hit_num:  37 hit_rate:  0.00185 user_num :  20000\n",
      " topk:  150  :  hit_num:  68 hit_rate:  0.0034 user_num :  20000\n",
      " topk:  200  :  hit_num:  99 hit_rate:  0.00495 user_num :  20000\n"
     ]
    }
   ],
   "source": [
    "metrics_recall(user_recall_doc_dict, user_click_last_df, topk=200)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
